-
Notifications
You must be signed in to change notification settings - Fork 427
Add option for selective op AC to filter mm shapes based on fqn #1380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
9e3b49b
to
3c4d97d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great idea!
Since this feature is advanced, could you also help test if the behavior is expected?
It seems this feature does not require distributed, so maybe we can add a unit test file in
https://github.com/pytorch/torchtitan/tree/main/tests/unit_tests
But if it doesn't make sense, feel free to do it in the way you prefer.
torchtitan/config_manager.py
Outdated
@@ -487,6 +487,20 @@ class ActivationCheckpoint: | |||
'int' (e.g., 2) for every nth layer, or 'op' for op level ac. | |||
""" | |||
|
|||
selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field( | |||
default_factory=lambda: [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems good enough to
default_factory=lambda: [] | |
default_factory=list |
or we can default to ["moe.router.gate"]
so that we don't need to define it in a lot of tomls.
O/w could you please also update the tomls in https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/llama4/train_configs
and
https://github.com/pytorch/torchtitan/tree/main/torchtitan/models/deepseek_v3/train_configs
if ( | ||
fqn | ||
not in ac_config.selective_op_ac_force_recompute_mm_shapes_by_fqns | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that in float8, we also filter by fqns, in which we are doing reversely
https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/utils.py#L25
I think one reason could be that the filter over there is applied to the whole model, so one fqn can help map to multiple layers / modules.
I think for AC there's not that much difference between the two. The benefit of doing it the other way may be users don't need to specify accurately the full relative fqn within the AC region. E.g. "router.gate" would also work.
I don't have a strong preference, but maybe let's be consistent with float8 if you don't have strong preference either.
torchtitan/config_manager.py
Outdated
@@ -487,6 +487,20 @@ class ActivationCheckpoint: | |||
'int' (e.g., 2) for every nth layer, or 'op' for op level ac. | |||
""" | |||
|
|||
selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we should prefer a shorter name over how accurate its meaning is. How about per_op_sac_filter_fqns
? Most users shouldn't really care about the details of implementation; if some users do, they can check the helper message and implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question. To help reduce the cognitive load of parsing the config file for the average user who I agree won't care about the impl, would it help if the default already include any moe router fqns for TorchTitan models per your other suggestion? This means most configs won't need to contain it at all, so most users won't see it and the advanced users using it will still benefit from a more explicit name.
I think this is consistent with most users already not being aware of what the per-op sac policy is at all, although we could potentially refactor things such that we have specific policies like policy="compute_intensive_excluding_every_other_matmul"
3c4d97d
to
bfb3a32
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I left some comments.
The unit test looks awesome, although IIRC our CPU unit test can't handle GPU tests
https://github.com/pytorch/torchtitan/blob/main/.github/workflows/unit_test_cpu.yaml
Do you think we can test AC on a CPU? If not, we can land the current one for now, and I'll try to find a way to run GPU unit tests later.
@@ -27,7 +27,7 @@ | |||
SequenceParallel, | |||
) | |||
|
|||
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP | |||
from torchtitan.config_manager import ActivationCheckpoint, JobConfig, TORCH_DTYPE_MAP |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe ActivationCheckpoint as ACConfig
|
||
def test_correctness(self): | ||
if not torch.cuda.is_available(): | ||
raise unittest.SkipTest("CUDA is unavailable") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe noob question:
Does AC require GPU to run? My intuition was it should be able to run on CPU.
class TestApplyAC(unittest.TestCase): | ||
def test_flops(self): | ||
if not torch.cuda.is_available(): | ||
raise unittest.SkipTest("CUDA is unavailable") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar question: Does AC / FlopCounterMode
require GPU to run?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AC and FlopCounterMode should not require GPU, but peak memory stats does. I can refactor out the flop counter test so that it runs if we are only able to run CPU-only.
@@ -237,7 +237,9 @@ def apply_tp( | |||
} | |||
|
|||
|
|||
def _apply_ac_to_transformer_block(module: nn.Module, ac_config): | |||
def _apply_ac_to_transformer_block( | |||
module: nn.Module, ac_config: ActivationCheckpoint, base_fqn: str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since in torchtitan we only apply AC at transformer block level, I feel the arg base_fqn
is less needed, in the sense that there should be rare cases where user apply per op SAC, but only wants to filter router.gate matmul in layer 1 but not layer 2.
Most use cases would be per_op_sac_force_recompute_mm_shapes_by_fqns = ["moe.router.gate"]
and moe.router.gate
should be already in module_fqn
without base_fqn
.
If that's the case, I think it's not necessary to add this field. Let me know if you think otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it is necessary but I wanted to be consistent with float8's fqn matching if that is based on the entire model's fqn. We can also avoid a line of documentation mentioning that fqns are actually relative to TransformerBlock. Let me know if you'd still like it changed, I think either is fine, but slightly preferred this direciton.
bfb3a32
to
c2cdb20
Compare
c2cdb20
to
b27dae7
Compare
Also see discussion in #1372
This PR: